"""Monitor queues and scale up/down services."""
import argparse
import functools
from importlib import resources
import os
import pathlib
import time

from cki_lib import messagequeue
from cki_lib import metrics
from cki_lib import misc
from cki_lib import yaml
from cki_lib.logger import get_logger
from kubernetes import client
from kubernetes import config as k8s_config
from kubernetes import dynamic

# Module-level logger used by Application for scaling decisions and errors.
LOGGER = get_logger('cki_tools.autoscaler')


class Application:
    """One application whose deployment is scaled based on its queue depth.

    Each configured application is tied to a message queue named
    ``queue_prefix`` + name, and to a Kubernetes Deployment named after the
    application (with a ``-staging`` suffix outside of production). `check`
    reads the queue depth and the current replica count, and updates the
    Deployment's ``scale`` subresource when a different replica count is due.
    """

    def __init__(self, name, config, queue):
        """Initialize the application.

        Args:
            name: application name, used to derive the queue and deployment names.
            config: per-application settings; the keys read by this class are
                queue_prefix, replicas_min, replicas_max, message_baseline,
                messages_per_replica, max_scale_up_step and max_scale_down_step.
            queue: message queue wrapper whose ``connect()`` is used as a
                context manager yielding a channel.
        """
        self.name = name
        self.config = config
        self.queue = queue
        self.queue_name = self.config['queue_prefix'] + self.name
        deploy_suffix = '' if misc.is_production() else '-staging'
        self.deploy_name = self.name + deploy_suffix

        # Use the service-account configuration mounted into the pod.
        k8s_config.load_incluster_config()
        self.deploy_client = dynamic.client.DynamicClient(client.ApiClient()).resources.get(
            api_version='apps/v1', kind='Deployment')

    @functools.cached_property
    def namespace(self) -> str:
        """Return the namespace of the pod this process runs in.

        The value is read once from the service-account mount and then cached.
        Surrounding whitespace is stripped so that a trailing newline can never
        end up in API request paths.
        """
        return pathlib.Path(
            '/var/run/secrets/kubernetes.io/serviceaccount/namespace'
        ).read_text(encoding='utf8').strip()

    def evaluate(self, messages, replicas):
        """Return the replica count to use for the given queue depth.

        The target is ``ceil(max(0, messages - message_baseline) /
        messages_per_replica)``, clamped to ``[replicas_min, replicas_max]``.
        If the bounds are inverted, ``replicas_max`` wins. The change relative
        to the current ``replicas`` is then limited to ``max_scale_up_step``
        when growing and ``max_scale_down_step`` when shrinking.

        Args:
            messages: number of messages waiting in the queue.
            replicas: current replica count of the deployment.

        Returns:
            The new replica count; equal to ``replicas`` if nothing should change.
        """
        per_replica = self.config['messages_per_replica']
        excess = max(0, messages - self.config['message_baseline'])
        # Integer ceiling division: a partial batch of messages still needs a replica.
        wanted = (excess + per_replica - 1) // per_replica
        new_replicas = min(self.config['replicas_max'],
                           max(self.config['replicas_min'], wanted))

        if new_replicas > replicas:
            LOGGER.debug('%s: should scale up', self.name)
            new_replicas = min(new_replicas, replicas + self.config['max_scale_up_step'])
        elif new_replicas < replicas:
            LOGGER.debug('%s: should scale down', self.name)
            new_replicas = max(new_replicas, replicas - self.config['max_scale_down_step'])

        return new_replicas

    def check(self):
        """Compare the queue depth with the deployment size and rescale if needed.

        A missing deployment is logged and skipped. Any other failure while
        fetching the deployment is logged with a traceback and the check is
        skipped. The new replica count is only applied when running in
        production or staging; otherwise the decision is only logged.
        """
        try:
            deploy = self.deploy_client.get(namespace=self.namespace, name=self.deploy_name)
        except dynamic.exceptions.NotFoundError:
            LOGGER.info('deploy/%s does not exist', self.deploy_name)
            return
        except Exception:  # pylint: disable=broad-except
            # Stay best-effort, but do not misreport API/permission/namespace
            # errors as a missing deployment.
            LOGGER.exception('%s: unable to get deploy/%s', self.name, self.deploy_name)
            return

        with self.queue.connect() as channel:
            # passive=True only inspects the existing queue; it never creates it.
            messages = channel.queue_declare(self.queue_name, passive=True).method.message_count
        replicas = deploy.spec.replicas
        LOGGER.debug('%s: %d messages, %d replicas', self.name, messages, replicas)

        if (new_replicas := self.evaluate(messages, replicas)) == replicas:
            LOGGER.debug('%s: replica count not changed', self.name)
            return

        LOGGER.info('scaling %s application=%s messages=%d replicas=%d->%d',
                    'up' if new_replicas > replicas else 'down',
                    self.name, messages, replicas, new_replicas)
        if misc.is_production_or_staging():
            # Server-side apply on the scale subresource; force_conflicts lets the
            # 'autoscaler' field manager take over spec.replicas from other managers.
            self.deploy_client.server_side_apply(
                namespace=self.namespace, name=f'{self.deploy_name}/scale',
                field_manager='autoscaler', force_conflicts=True,
                body={
                    'apiVersion': 'autoscaling/v1', 'kind': 'Scale',
                    'spec': {'replicas': new_replicas},
                },
            )


def main(args: list[str] | None = None) -> None:
    """Run the autoscaler forever.

    The command line takes no options. Parsing it only provides ``--help`` and
    rejects unexpected arguments. The per-application configuration is loaded
    from ``AUTOSCALER_CONFIG`` (inline contents) or ``AUTOSCALER_CONFIG_PATH``
    (file), with the package's ``schema.yml`` as schema. Every configured
    application is then checked once per ``REFRESH_PERIOD`` seconds (default 30).

    Args:
        args: command-line arguments; ``None`` means argparse uses ``sys.argv``.
    """
    argparse.ArgumentParser().parse_args(args)
    metrics.prometheus_init()

    config = yaml.load(
        schema_path=resources.files(__package__) / 'schema.yml',
        contents=os.environ.get('AUTOSCALER_CONFIG'),
        file_path=os.environ.get('AUTOSCALER_CONFIG_PATH'),
    )

    period = misc.get_env_int('REFRESH_PERIOD', 30)
    # Keep the broker connection alive across the sleep between check rounds.
    message_queue = messagequeue.MessageQueue(keepalive_s=period * 2)
    applications = [
        Application(app_name, app_config, message_queue)
        for app_name, app_config in config.items()
    ]

    while True:
        for application in applications:
            application.check()
        time.sleep(period)
